# Model definition: a class path plus its `config` mapping.
# NOTE: `{{ ... }}` below is a template placeholder, so this file is presumably
# rendered (Jinja-style) before it is parsed as YAML.
model_config:
  class_name: tensorflow_asr.models.transducer.rnnt>RnnTransducer
  config:
    name: rnn_transducer
    # Feature extraction: 25 ms frames every 10 ms, 80 log-mel feature bins.
    speech_config:
      sample_rate: 16000
      frame_ms: 25
      stride_ms: 10
      num_feature_bins: 80
      nfft: 512
      feature_type: log_mel_spectrogram
      # Masking augmentation, always applied (prob: 1.0) on both axes.
      augmentation_config:
        feature_augment:
          time_masking:
            prob: 1.0
            num_masks: 5
            # NOTE(review): -1 presumably means "no fixed mask width", leaving
            # p_upperbound to bound the mask size — confirm.
            mask_factor: -1
            p_upperbound: 0.05
          freq_masking:
            prob: 1.0
            num_masks: 1
            mask_factor: 27
    # The two lists below have four entries, matching encoder_nlayers: 4
    # (presumably one entry per encoder layer).
    encoder_reduction_positions: [post, post, post, post]
    # Time-reduction factors: the factor 3 in the first entry downsamples the
    # 10 ms stride to 30 ms; a further factor 2 sits in the third entry
    # (index 2).
    # NOTE(review): the previous note described the factor 2 as "after second
    # layer" — confirm how list indices and positions map to layers.
    encoder_reduction_factors: [3, 0, 2, 0]
    encoder_dmodel: 320
    encoder_rnn_type: lstm
    encoder_rnn_units: 1024
    encoder_nlayers: 4
    encoder_layer_norm: true
    prediction_label_encode_mode: embedding
    prediction_embed_dim: 512
    prediction_num_rnns: 1
    prediction_rnn_units: 1024
    prediction_rnn_type: lstm
    prediction_rnn_unroll: false
    prediction_layer_norm: true
    prediction_projection_units: 0
    joint_dim: 320
    prejoint_encoder_linear: true
    prejoint_prediction_linear: true
    postjoint_linear: false
    joint_activation: tanh
    joint_mode: add
    blank: 0
    # Template placeholder; intentionally unquoted so the rendered value keeps
    # its native type (presumably an integer).
    vocab_size: {{decoder_config.vocabsize}}
    kernel_regularizer:
      class_name: l2
      config:
        # Written with an explicit decimal point: YAML 1.1 loaders (e.g. stock
        # PyYAML) read a bare `1e-6` as a string rather than a float.
        l2: 1.0e-6

# Training setup: optimizer, noise/gradient options, batching and callbacks.
# NOTE: `{{ ... }}` values are template placeholders, so this file is
# presumably rendered (Jinja-style) before it is parsed as YAML.
learning_config:
  optimizer_config:
    class_name: Adam
    config:
      learning_rate:
        class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule
        config:
          # Same value as model_config.config.encoder_dmodel (320);
          # presumably the two should be kept in sync — confirm.
          dmodel: 320
          warmup_steps: 10000
          max_lr: null
          # Floats in this block use an explicit decimal point: YAML 1.1
          # loaders (e.g. stock PyYAML) read bare forms such as `1e-6` as
          # strings rather than floats.
          min_lr: 1.0e-6
          scale: 2.0
      beta_1: 0.9
      beta_2: 0.98
      epsilon: 1.0e-9
      weight_decay: 1.0e-6

  # NOTE(review): presumably noise added to the prediction network from step
  # 20000 with standard deviation 0.075 — confirm against the trainer.
  gwn_config:
    predict_net_step: 20000
    predict_net_stddev: 0.075

  gradn_config: null

  # NOTE(review): ga_steps is presumably the number of gradient-accumulation
  # steps, which would give an effective batch of 4 * 8 = 32 — confirm.
  batch_size: 4
  ga_steps: 8
  num_epochs: 300

  callbacks:
    - class_name: tensorflow_asr.callbacks>TerminateOnNaN
      config: {}
    - class_name: tensorflow_asr.callbacks>ModelCheckpoint
      config:
        # `{epoch:02d}` uses single braces, so it is not a template
        # placeholder; presumably the callback fills it in per epoch.
        filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5
        save_best_only: false
        save_weights_only: true
        save_freq: epoch
    - class_name: tensorflow_asr.callbacks>TensorBoard
      config:
        log_dir: {{modeldir}}/tensorboard
        histogram_freq: 0
        write_graph: false
        write_images: false
        write_steps_per_second: false
        update_freq: batch
        profile_batch: 0
    - class_name: tensorflow_asr.callbacks>KaggleModelBackupAndRestore
      config:
        # Left unquoted on purpose: quoting would turn an empty expansion
        # into an empty string instead of null.
        model_handle: {{kaggle_model_handle}}
        model_dir: {{modeldir}}
        save_freq: epoch
